package pers.vic.boot.security.shiro.extend;

import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.subject.ExecutionException;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.mgt.FilterChainManager;
import org.apache.shiro.web.filter.mgt.FilterChainResolver;
import org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver;
import org.apache.shiro.web.mgt.WebSecurityManager;
import org.apache.shiro.web.servlet.AbstractShiroFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.BeanInitializationException;
import pers.vic.boot.base.tool.Tools;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;

/**
 * Shiro filter factory bean that lets selected request URLs skip the update of the session's
 * last-access time.
 *
 * <p>{@link #createInstance()} is copied from {@link ShiroFilterFactoryBean#createInstance()};
 * the only difference is that it builds a {@link CustomerSpringShiroFilter}, whose
 * {@code doFilterInternal} consults {@link #willIgnoreUpdateLastAccessTime(ServletRequest)}
 * before calling {@code updateSessionLastAccessTime}.
 *
 * @author Vic.xu
 * @date 2021/08/23
 */
public class CustomerShiroFilterFactoryBean extends ShiroFilterFactoryBean {

    private static final Logger log = LoggerFactory.getLogger(CustomerShiroFilterFactoryBean.class);

    /**
     * URLs for which the session's last-access time is NOT refreshed. Matching is exact, which is
     * considered sufficient because such requests are expected to be few.
     *
     * <p>NOTE(review): this set is static, so it is shared by every instance of this bean. It is
     * populated through {@link #setIgnoreUpdateLastAccessTimeUrl(Set)} and read on every filtered
     * request.
     */
    private static final Set<String> ignoreUpdateLastAccessTimeUrl = new HashSet<String>();

    /**
     * Tells whether the given request should skip the update of the session's last-access time.
     *
     * @param request the current request; only an {@link HttpServletRequest} can match
     * @return {@code true} if the URL computed by {@code Tools.getRequestUrl} for this request is
     *         exactly one of the configured ignore URLs, {@code false} otherwise
     */
    public static boolean willIgnoreUpdateLastAccessTime(ServletRequest request) {
        // Nothing configured: avoid computing the request URL on every request.
        if (ignoreUpdateLastAccessTimeUrl.isEmpty()) {
            return false;
        }
        // Non-HTTP requests have no URL to match against; never skip the update for them.
        if (!(request instanceof HttpServletRequest)) {
            return false;
        }
        String url = Tools.getRequestUrl((HttpServletRequest) request);
        return ignoreUpdateLastAccessTimeUrl.contains(url);
    }

    /**
     * Adds URLs whose requests must not refresh the session's last-access time.
     *
     * <p>The URLs are added to a static set shared by all instances; previously configured URLs
     * are kept. A {@code null} argument is ignored.
     *
     * @param ignoreUpdateLastAccessTimeUrl exact request URLs to ignore
     */
    public void setIgnoreUpdateLastAccessTimeUrl(Set<String> ignoreUpdateLastAccessTimeUrl) {
        if (ignoreUpdateLastAccessTimeUrl == null) {
            return;
        }
        // The parameter shadows the static field, so the field must be qualified explicitly;
        // an unqualified addAll would add the argument to itself and leave the field empty.
        CustomerShiroFilterFactoryBean.ignoreUpdateLastAccessTimeUrl.addAll(ignoreUpdateLastAccessTimeUrl);
    }

    /**
     * Builds the Shiro filter. Copied from {@link ShiroFilterFactoryBean#createInstance()}, except
     * that a {@link CustomerSpringShiroFilter} is returned.
     *
     * @return the configured Shiro filter
     * @throws BeanInitializationException if no security manager is set, or if it does not
     *         implement {@link WebSecurityManager}
     * @throws Exception if building the filter chain manager fails
     */
    @Override
    protected AbstractShiroFilter createInstance() throws Exception {
        log.debug("Creating Shiro Filter instance.");

        SecurityManager securityManager = getSecurityManager();
        if (securityManager == null) {
            String msg = "SecurityManager property must be set.";
            throw new BeanInitializationException(msg);
        }

        if (!(securityManager instanceof WebSecurityManager)) {
            String msg = "The security manager does not implement the WebSecurityManager interface.";
            throw new BeanInitializationException(msg);
        }

        FilterChainManager manager = createFilterChainManager();

        // Expose the constructed FilterChainManager by wrapping it in a FilterChainResolver:
        // AbstractShiroFilter implementations only know about resolvers, not managers.
        PathMatchingFilterChainResolver chainResolver = new PathMatchingFilterChainResolver();
        chainResolver.setFilterChainManager(manager);

        // Create the customized filter and hand it the SecurityManager and FilterChainResolver.
        return new CustomerSpringShiroFilter((WebSecurityManager) securityManager, chainResolver);
    }

    /**
     * Replacement for Shiro's {@code ShiroFilterFactoryBean.SpringShiroFilter}. Its
     * {@code doFilterInternal} only updates the session's last-access time when
     * {@link CustomerShiroFilterFactoryBean#willIgnoreUpdateLastAccessTime(ServletRequest)}
     * returns {@code false}.
     *
     * @author Vic.xu
     * @date 2021/08/23
     */
    private static final class CustomerSpringShiroFilter extends AbstractShiroFilter {

        /**
         * Creates the filter.
         *
         * @param webSecurityManager the security manager to use; must not be {@code null}
         * @param resolver the filter chain resolver; when {@code null} no resolver is set here
         * @throws IllegalArgumentException if {@code webSecurityManager} is {@code null}
         */
        private CustomerSpringShiroFilter(WebSecurityManager webSecurityManager, FilterChainResolver resolver) {
            super();
            if (webSecurityManager == null) {
                throw new IllegalArgumentException("WebSecurityManager property cannot be null.");
            }
            setSecurityManager(webSecurityManager);

            if (resolver != null) {
                setFilterChainResolver(resolver);
            }
        }

        /**
         * Copied from {@link AbstractShiroFilter#doFilterInternal}, with one addition: the session
         * last-access time update is skipped for requests matched by
         * {@link CustomerShiroFilterFactoryBean#willIgnoreUpdateLastAccessTime(ServletRequest)}.
         *
         * @throws ServletException if the chain throws it, or wrapping any other non-IO failure
         * @throws IOException if the chain throws it
         */
        @Override
        protected void doFilterInternal(ServletRequest servletRequest, ServletResponse servletResponse,
                                        FilterChain chain) throws ServletException, IOException {
            Throwable t = null;

            try {
                final ServletRequest request = prepareServletRequest(servletRequest, servletResponse, chain);
                final ServletResponse response = prepareServletResponse(request, servletResponse, chain);
                final FilterChain filterChain = chain;

                final Subject subject = createSubject(request, response);

                subject.execute(new Callable<Object>() {
                    @Override
                    public Object call() throws Exception {
                        // The only behavioral difference from the superclass: conditionally
                        // skip refreshing the session's last-access time.
                        if (!willIgnoreUpdateLastAccessTime(request)) {
                            updateSessionLastAccessTime(request, response);
                        }
                        executeChain(request, response, filterChain);
                        return null;
                    }
                });
            } catch (ExecutionException ex) {
                // Subject.execute wraps whatever the callable threw; unwrap it.
                t = ex.getCause();
            } catch (Throwable throwable) {
                t = throwable;
            }

            if (t != null) {
                if (t instanceof ServletException) {
                    throw (ServletException) t;
                }
                if (t instanceof IOException) {
                    throw (IOException) t;
                }
                // Not one of the two exceptions allowed by the filter method signature - wrap it.
                String msg = "Filtered request failed.";
                throw new ServletException(msg, t);
            }
        }
    }

}
